import argparse
import os
import sys

import numpy as np

from Neural_epsilon import Neural_epsilon
from NeuralTS import NeuralTS
from NeuralUCB import NeuralUCBDiag
from load_data import load_yelp, load_mnist_adv, load_movielen, Bandit_multi, synthetic


# Per-dataset hyperparameters, transcribed verbatim from the original
# if/elif chain. Insertion order is the default dataset run order.
# NOTE(review): ``step_ep``, ``lambda_ts``, ``nu_ts`` and ``step_ts`` are
# defined for every dataset but are NOT passed to any model: Neural_epsilon
# receives ``ep``/``step`` and NeuralTS receives ``lambda_``/``nu_``/``step``,
# exactly as the original script did. If the dedicated values were intended,
# wire them in ``_build_model`` -- doing so changes experimental results.
_HYPERPARAMS = {
    'covertype': dict(lambda_=0.01, nu_=0.001, step=0.005,
                      step_ep=0.005, ep=0.05,
                      lambda_ts=0.01, nu_ts=0.001, step_ts=0.01),
    'MagicTelescope': dict(lambda_=0.01, nu_=0.001, step=0.005,
                           step_ep=0.01, ep=0.05,
                           lambda_ts=0.01, nu_ts=0.001, step_ts=0.01),
    'shuttle': dict(lambda_=0.01, nu_=0.001, step=0.005,
                    step_ep=0.01, ep=0.05,
                    lambda_ts=0.01, nu_ts=0.001, step_ts=0.01),
    'mushroom': dict(lambda_=0.01, nu_=0.001, step=0.005,
                     step_ep=0.01, ep=0.01,
                     lambda_ts=0.01, nu_ts=0.01, step_ts=0.01),
    'fashion': dict(lambda_=0.01, nu_=0.001, step=0.005,
                    step_ep=0.01, ep=0.1,
                    lambda_ts=0.01, nu_ts=0.01, step_ts=0.01),
    'Plants': dict(lambda_=0.01, nu_=0.001, step=0.005,
                   step_ep=0.01, ep=0.05,
                   lambda_ts=0.01, nu_ts=0.001, step_ts=0.01),
}

# Methods whose classes are actually imported by this file. KernelUCB, LinUCB
# and NeuralNoExplore were referenced by earlier versions of this script but
# never imported, so selecting them could only end in a NameError.
_SUPPORTED_METHODS = ("Neural_epsilon", "NeuralTS", "NeuralUCB")


def _load_bandit(name):
    """Return a fresh bandit environment for dataset ``name``."""
    if name == 'mnist':
        return load_mnist_adv()
    if name in ('cos', 'square', 'quad'):
        return synthetic(name)
    return Bandit_multi(name)


def _build_model(method, dim, hp):
    """Instantiate the learner ``method`` for contexts of size ``dim``.

    ``hp`` is one entry of ``_HYPERPARAMS``. Raises ValueError for a method
    not in ``_SUPPORTED_METHODS``.
    """
    if method == "Neural_epsilon":
        return Neural_epsilon(dim, hp['ep'], hp['step'])
    if method == "NeuralTS":
        return NeuralTS(dim, lamdba=hp['lambda_'], nu=hp['nu_'],
                        lr=hp['step'], hidden=100)
    if method == "NeuralUCB":
        return NeuralUCBDiag(dim, lamdba=hp['lambda_'], nu=hp['nu_'],
                             lr=hp['step'], hidden=100)
    raise ValueError("method is not defined: {}".format(method))


def _is_train_round(t):
    """Return True every 10 rounds before round 1000, every 100 rounds after."""
    return t % (10 if t < 1000 else 100) == 0


def _error_rates(error, count):
    """Return per-arm ``error / count``, with 0 for arms whose count is 0.

    A plain ``error / count`` yields NaN for 0/0, and ``np.argmax`` returns the
    index of the first NaN, so an arm that was never counted would be picked.
    """
    return np.divide(error, count, out=np.zeros_like(error, dtype=float),
                     where=count > 0)


def _run_once(method, b, model, n_rounds, block):
    """Play ``n_rounds`` rounds of ``model`` on bandit ``b``.

    Returns the list of cumulative regrets, one entry per round.
    """
    regrets = []
    sum_regret = 0
    print("Round; Regret; Regret/Round")
    # Per-arm tallies indexed by the ``arm`` value returned from ``b.step``
    # (its meaning is defined in load_data -- presumably the sample's true
    # arm; TODO confirm). ``error`` counts rounds where the chosen arm paid 0.
    error = np.zeros(b.n_arm)
    count = np.zeros(b.n_arm)
    k = -1  # b.step(-1) for the whole first block
    for t in range(n_rounds):
        # At every block boundary after the first block, re-choose the step
        # argument as the arm with the highest error rate observed so far.
        if t >= block and t % block == 0:
            rate = _error_rates(error, count)
            k = int(np.argmax(rate))
            print(error, count, rate, k)
        context, rwd, arm = b.step(k)

        arm_select = model.select(context)
        reward = rwd[arm_select]
        count[arm] += 1
        if reward == 0:
            error[arm] += 1

        train_now = _is_train_round(t)
        if method == "NeuralTS":
            # NeuralTS takes the observation through train(); the third
            # argument is True exactly on training rounds.
            model.train(context[arm_select], reward, train_now)
        else:
            model.update(context[arm_select], reward)
            if train_now:
                model.train(t)

        regret = np.max(rwd) - reward
        sum_regret += regret
        regrets.append(sum_regret)
        if t % 50 == 0:
            print('{}: {:}, {:.4f}'.format(t, sum_regret, sum_regret/(t+1)))
    return regrets


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Run baselines')
    parser.add_argument('--dataset', nargs='+', default=None, type=str,
                        help='datasets to run (default: all of {})'.format(
                            ', '.join(_HYPERPARAMS)))

    parser.add_argument("--method", nargs="+", default=["Neural_epsilon","NeuralUCB","NeuralTS"], help='list: ["Neural_epsilon", "NeuralTS", "NeuralUCB"]')

    # Kept for CLI compatibility; none of the supported methods reads them.
    parser.add_argument('--lamdba', default='0.1', type=float, help='Regulization Parameter')
    parser.add_argument('--nu', default='0.001', type=float, help='Exploration Parameter')

    parser.add_argument('--runs', default=20, type=int, help='independent runs per dataset')
    parser.add_argument('--rounds', default=5000, type=int, help='rounds per run')
    parser.add_argument('--block', default=500, type=int, help='rounds between arm re-selection')
    parser.add_argument('--out_dir', default='./results/iclr_results/NeuralEps',
                        help='directory for the saved regret arrays')

    args = parser.parse_args()

    # Validate everything up front so a typo does not surface only after
    # hours of earlier experiments (and exits non-zero, unlike sys.exit()).
    unknown = [m for m in args.method if m not in _SUPPORTED_METHODS]
    if unknown:
        parser.error("method is not defined: {} (supported: {})".format(
            ', '.join(unknown), ', '.join(_SUPPORTED_METHODS)))
    datasets = args.dataset if args.dataset else list(_HYPERPARAMS)
    missing = [d for d in datasets if d not in _HYPERPARAMS]
    if missing:
        parser.error("no hyperparameters for dataset(s): {}".format(', '.join(missing)))
    if args.block <= 0:
        parser.error("--block must be positive")

    os.makedirs(args.out_dir, exist_ok=True)

    print("running methods:", args.method)
    for method in args.method:
        for d in datasets:
            hp = _HYPERPARAMS[d]
            # Reset per dataset: each saved file must hold only this dataset's runs.
            regrets_all = []
            for i in range(args.runs):
                b = _load_bandit(d)
                model = _build_model(method, b.dim, hp)
                regrets = _run_once(method, b, model, args.rounds, args.block)
                print("run:", i, "; ", "regret:", regrets[-1] if regrets else 0)
                regrets_all.append(regrets)
            np.save(os.path.join(args.out_dir, "{}_regret_{}".format(method, d)), regrets_all)

    

    

    

    

    

    

